# -*- coding: utf-8 -*-

import numpy as np
import tensorflow as tf

from config import num_class
from my_batch_norm import bn_layer_top

slim = tf.contrib.slim

_BATCH_NORM_DECAY = 0.9
_BATCH_NORM_EPSILON = 1e-05
_LEAKY_RELU = 0.1

_ANCHORS = [(10, 13), (16, 30), (33, 23),
            (30, 61), (62, 45), (59, 119),
            (116, 90), (156, 198), (373, 326)]


def darknet53(inputs):
    """
    Builds the Darknet backbone used by YOLO v3.

    After an initial 3x3/32 convolution, the network runs five stages. Each
    stage is a stride-2 3x3 convolution followed by residual blocks. The stage
    widths are 64, 128, 256, 512 and 1024 filters, with 1, 2, 4, 4 and 2
    residual blocks respectively. The published Darknet-53 uses 1, 2, 8, 8, 4
    blocks, so this appears to be a reduced variant.

    Returns:
      (route_1, route_2, outputs): the feature maps after the 256-, 512- and
      1024-filter stages, i.e. after three, four and five stride-2
      convolutions.
    """
    net = _conv2d_fixed_padding(inputs, 32, 3)
    routes = []
    # (filters of the stride-2 downsampling conv, number of residual blocks)
    stages = ((64, 1), (128, 2), (256, 4), (512, 4), (1024, 2))
    for stage_filters, num_blocks in stages:
        net = _conv2d_fixed_padding(net, stage_filters, 3, strides=2)
        for _ in range(num_blocks):
            # Each block squeezes to half the stage width and expands back.
            net = _darknet53_block(net, stage_filters // 2)
        if stage_filters in (256, 512):
            # Intermediate maps consumed by the finer YOLO detection scales.
            routes.append(net)
    route_1, route_2 = routes
    return route_1, route_2, net


def _conv2d_fixed_padding(inputs, filters, kernel_size, strides=1):
    """
    Wraps slim.conv2d so that padding does not depend on the input size.

    Stride-1 convolutions use 'SAME' padding. Strided convolutions pad the
    input explicitly with _fixed_padding (an amount determined by kernel_size
    alone) and then convolve with 'VALID' padding.
    """
    padded = _fixed_padding(inputs, kernel_size) if strides > 1 else inputs
    return slim.conv2d(padded, filters, kernel_size, stride=strides,
                       padding='SAME' if strides == 1 else 'VALID')


def _darknet53_block(inputs, filters):
    """
    Residual block: a 1x1 conv to `filters` channels, then a 3x3 conv to
    `filters * 2` channels, with the block input added back.

    The residual add requires `inputs` to already have `filters * 2`
    channels, which holds at every call site in darknet53.
    """
    squeezed = _conv2d_fixed_padding(inputs, filters, 1)
    expanded = _conv2d_fixed_padding(squeezed, filters * 2, 3)
    return expanded + inputs


@tf.contrib.framework.add_arg_scope
def _fixed_padding(inputs, kernel_size, *args, mode='CONSTANT', **kwargs):
    """
    Pads the input along the spatial dimensions independently of input size.

    The total padding ``kernel_size - 1`` is split between the start and the
    end of each spatial axis; when it is odd, the extra pixel goes to the end.
    Batch and channel axes are never padded.

    Args:
      inputs: A tensor of size [batch, channels, height_in, width_in] or
        [batch, height_in, width_in, channels] depending on data_format.
      kernel_size: The kernel to be used in the conv2d or max_pool2d operation.
                   Should be a positive integer.
      *args: Unused; accepted for call-signature compatibility.
      mode: The mode for tf.pad.
      **kwargs: ``data_format`` ('NHWC' or 'NCHW') is read from here. It is
        normally injected by an enclosing ``slim.arg_scope`` and defaults to
        'NHWC' when not supplied. Any other injected keys (e.g. ``reuse``,
        which yolo_v3's arg_scope passes) are ignored.

    Returns:
      A tensor with the same format as the input with the data either intact
      (if kernel_size == 1) or padded (if kernel_size > 1).
    """
    # Without a default, calling this outside an arg_scope that provides
    # data_format raised KeyError.
    data_format = kwargs.get('data_format', 'NHWC')

    pad_total = kernel_size - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg

    if data_format == 'NCHW':
        padded_inputs = tf.pad(inputs, [[0, 0], [0, 0],
                                        [pad_beg, pad_end],
                                        [pad_beg, pad_end]],
                               mode=mode)
    else:
        padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end],
                                        [pad_beg, pad_end], [0, 0]], mode=mode)
    return padded_inputs


def _yolo_block(inputs, filters):
    """
    Five-convolution YOLO neck block with a branch point.

    Runs 1x1/filters, 3x3/2*filters and 1x1/filters convolutions; that result
    is returned as `route`. One more 3x3/2*filters convolution is applied on
    top of it for the detection output.

    Returns:
      (route, outputs)
    """
    net = inputs
    for out_channels, size in ((filters, 1), (filters * 2, 3), (filters, 1)):
        net = _conv2d_fixed_padding(net, out_channels, size)
    route = net
    return route, _conv2d_fixed_padding(route, filters * 2, 3)


def _get_size(shape, data_format):
    if len(shape) == 4:
        shape = shape[1:]
    return shape[1:3] if data_format == 'NCHW' else shape[0:2]


def _upsample(inputs, out_shape, data_format='NCHW'):
    """
    Nearest-neighbour upsamples `inputs` to the spatial size of `out_shape`.

    Args:
      inputs: 4-D feature map laid out according to `data_format`.
      out_shape: Static shape of the tensor whose height/width should be
        matched (e.g. ``t.get_shape().as_list()``), in the same
        `data_format` as `inputs`. It may include the batch dimension
        (4-D) or omit it (3-D).
      data_format: 'NCHW'; any other value is treated as NHWC.

    Returns:
      The upsampled tensor in the original data_format, named 'upsampled'.
    """
    # Fix: the previous code read (W, H) instead of (H, W) from out_shape,
    # which transposed the output for non-square inputs and broke the
    # following concat with the route tensor.
    new_height, new_width = _get_size(out_shape, data_format)

    # tf.image.resize_nearest_neighbor accepts input in format NHWC
    if data_format == 'NCHW':
        inputs = tf.transpose(inputs, [0, 2, 3, 1])

    inputs = tf.image.resize_nearest_neighbor(inputs, (new_height, new_width))

    # back to NCHW if needed
    if data_format == 'NCHW':
        inputs = tf.transpose(inputs, [0, 3, 1, 2])

    return tf.identity(inputs, name='upsampled')


def _yolo_head(inputs, num_classes):
    """
    Final 1x1 prediction convolution for one YOLO output scale.

    Produces 3 * (7 + num_classes) channels with a zero-initialised bias and
    no normaliser or activation, i.e. raw outputs. The factor 3 presumably
    matches the three anchors per scale, and 7 the per-box non-class terms.
    NOTE(review): confirm against the loss / decoder.
    """
    return slim.conv2d(inputs, 3 * (7 + num_classes), 1,
                       stride=1, normalizer_fn=None,
                       activation_fn=None,
                       biases_initializer=tf.zeros_initializer())


def yolo_v3(inputs, is_training, data_format='NHWC', reuse=False,
            num_classes=None):
    """
    Creates YOLO v3 model.

    :param inputs: a 4-D tensor of size [batch_size, height, width, channels].
        Dimension batch_size may be undefined. The channel order is RGB.
        Values are divided by 255 inside the graph.
    :param is_training: whether is training or not; forwarded to bn_layer_top.
    :param data_format: data format NCHW or NHWC used inside the network.
        `inputs` is always NHWC and is transposed when NCHW is requested.
    :param reuse: whether or not the network and its variables should be reused.
    :param num_classes: number of predicted classes. Defaults to
        ``config.num_class`` when None.
    :return: tuple (raw1, raw2, raw3) of raw prediction maps, each with
        3 * (7 + num_classes) channels in `data_format` layout. raw1 comes
        from the deepest (coarsest) backbone features; raw2 and raw3 are at
        the spatial sizes of route_2 and route_1 respectively.
    """
    if num_classes is None:
        num_classes = num_class
    # Channel axis for concatenating upsampled features with backbone routes.
    concat_axis = 1 if data_format == 'NCHW' else 3

    # transpose the inputs to NCHW
    if data_format == 'NCHW':
        inputs = tf.transpose(inputs, [0, 3, 1, 2])

    # normalize values to range [0..1] (assumes 0..255 pixel input)
    inputs = inputs / 255

    # Shared data_format/reuse for every conv and pad; hidden convs get
    # bn_layer_top normalisation, no bias, and a leaky ReLU.
    with slim.arg_scope([slim.conv2d, _fixed_padding], data_format=data_format, reuse=reuse):
        with slim.arg_scope([slim.conv2d],
                            normalizer_fn=bn_layer_top,
                            biases_initializer=None,
                            activation_fn=lambda x: tf.nn.leaky_relu(x, alpha=_LEAKY_RELU)):
            with slim.arg_scope([bn_layer_top],
                                is_training=is_training):
                with tf.variable_scope('darknet-53'):
                    # Strides 8, 16, 32 (e.g. 52, 26, 13 for a 416x416 input).
                    route_1, route_2, inputs = darknet53(inputs)

                with tf.variable_scope('yolo-v3'):
                    # Scale 1: deepest backbone features.
                    route, inputs = _yolo_block(inputs, 512)
                    raw1 = _yolo_head(inputs, num_classes)

                    # Scale 2: upsample to route_2's size and merge with it.
                    inputs = _conv2d_fixed_padding(route, 256, 1)
                    inputs = _upsample(inputs, route_2.get_shape().as_list(), data_format)
                    inputs = tf.concat([inputs, route_2], axis=concat_axis)
                    route, inputs = _yolo_block(inputs, 256)
                    raw2 = _yolo_head(inputs, num_classes)

                    # Scale 3: upsample to route_1's size and merge with it.
                    inputs = _conv2d_fixed_padding(route, 128, 1)
                    inputs = _upsample(inputs, route_1.get_shape().as_list(), data_format)
                    inputs = tf.concat([inputs, route_1], axis=concat_axis)
                    _, inputs = _yolo_block(inputs, 128)
                    raw3 = _yolo_head(inputs, num_classes)

    return raw1, raw2, raw3
